% Auxiliary Function: gen_stats
%   Description: Computes summary statistics for Monte Carlo simulation results.
%   Inputs:
%       beta:  Estimates. Rows index replications and columns index the
%              estimators/parameters being evaluated; statistics are taken
%              down each column.
%       beta0: (optional) True parameter value(s), a scalar or one value per
%              column of beta.
%       se:    (optional) Standard-error estimates matching the rows of beta,
%              with either one column (shared by all columns of beta) or one
%              column per column of beta. A third dimension, if present,
%              indexes alternative standard-error estimators.
%       alpha: (optional) Significance levels used when se is supplied.
%              Defaults to [0.10 0.05 0.01].
%   Outputs:
%       stats: struct with fields
%           mean, med, std, iqr           - always computed
%           bias, med_bias, rmse          - computed when beta0 is supplied
%           rp, length, oracle_length, cv - computed when se is supplied; each
%                                           is numel(alpha) x size(beta,2) x
%                                           size(se,3) (see compute_size)

function stats = gen_stats(beta, beta0, se, alpha)
    % Mean, median, standard deviation and robust standard deviation of the estimates
    stats.mean = mean(beta);
    stats.med = median(beta);
    % Normalized by the number of replications (not R-1), so that the rmse
    % below equals sqrt(mean((beta - beta0).^2)) exactly.
    stats.std = std(beta,1);
    % 1.3490 ~= 2*norminv(0.75) is the IQR of a standard normal, so this is an
    % outlier-robust, normal-consistent estimate of the standard deviation.
    stats.iqr = iqr(beta)/1.3490;
    % Bias, median bias and RMSE require the true parameter values
    if nargin > 1
        stats.bias = stats.mean - beta0;
        stats.med_bias = stats.med - beta0;
        stats.rmse = sqrt(stats.bias.^2 + stats.std.^2);
    end
    % Rejection rates, CI lengths and critical values require standard errors
    if nargin > 2
        if nargin < 4 || isempty(alpha)
            alpha = [0.10 0.05 0.01]; % default significance levels
        end
        alpha = alpha(:)'; % row vector: one output row per level
        n_a = numel(alpha);
        n_e = size(beta,2);
        n_se = size(se,3);
        rp_arr = zeros(n_a,n_e,n_se);
        CI_length_arr = zeros(n_a,n_e,n_se);
        CI_oracle_length_arr = zeros(n_a,n_e,n_se);
        cv_arr = zeros(n_a,n_e,n_se);
        % One set of results per standard-error estimator (third dimension of se)
        for i_se = 1:n_se
            [rp, CI_length, CI_oracle_length, cv] = compute_size(beta,beta0,se(:,:,i_se),alpha);
            rp_arr(:,:,i_se) = rp;
            CI_length_arr(:,:,i_se) = CI_length;
            CI_oracle_length_arr(:,:,i_se) = CI_oracle_length;
            cv_arr(:,:,i_se) = cv;
        end
        stats.rp = rp_arr;
        stats.length = CI_length_arr;
        stats.oracle_length = CI_oracle_length_arr;
        stats.cv = cv_arr;
    end
end

%%
% Auxiliary Function: compute_size
%   Description:
%       Evaluates, across replications, how often normal-based two-sided tests
%       reject at the true parameter value, and how long the corresponding
%       confidence intervals are.
%   Inputs:
%       beta:  Estimates (rows: replications, columns: estimators/parameters)
%       beta0: True parameter value(s), a scalar or one value per column of beta
%       se:    Standard-error estimates with the same rows as beta and either one
%              column (shared by all columns of beta) or one column per column
%       alpha: Significance level(s), a scalar or vector (any orientation)
%   Outputs:
%       rp:               Rejection probability: share of replications with
%                         |t| >= norminv(1-alpha/2); coverage is 1 - rp
%       CI_length:        Average length 2*se*norminv(1-alpha/2) of the intervals
%       CI_oracle_length: Average interval length when every column uses the
%                         least favorable (largest across columns) empirical
%                         critical value
%       cv:               Empirical critical values: the (1-alpha) quantile of |t|
%   All outputs are numel(alpha) x size(beta,2).

function [rp, CI_length, CI_oracle_length, cv] = compute_size(beta, beta0, se, alpha)
    % Use a row vector so every broadcast below yields replications x numel(alpha)
    alpha = alpha(:)';
    n_a = numel(alpha);
    n_e = size(beta,2);
    % A single column of standard errors is shared by every column of beta
    if size(se,2) == 1 && n_e > 1
        se = repmat(se,1,n_e);
    elseif size(se,2) ~= n_e
        error('compute_size:dimMismatch', ...
            'se must have 1 or %d columns, but has %d.', n_e, size(se,2));
    end
    % Two-sided normal critical values (loop invariant)
    z = norminv(1-alpha/2);
    % Absolute t-statistics evaluated at the true parameter values
    t_stat = abs((beta-beta0)./se);
    rp = zeros(n_a,n_e);
    CI_length = zeros(n_a,n_e);
    cv = zeros(n_a,n_e);
    for i_e = 1:n_e
        % Reduce explicitly over dimension 1 (replications); with a single
        % replication the default dimension would average across alpha levels.
        rp(:,i_e) = mean(bsxfun(@ge,t_stat(:,i_e),z),1)';
        CI_length(:,i_e) = mean(2*se(:,i_e)*z,1)';
        q = quantile(t_stat(:,i_e),1-alpha);
        cv(:,i_e) = q(:);
    end
    % Least favorable critical value: the largest empirical one across columns
    LFCV = max(cv,[],2);
    % Outer product: row i_a is 2*LFCV(i_a) times the average se of each column
    CI_oracle_length = 2*LFCV*mean(se,1);
end
